from archs.cifar_resnet import resnet as resnet_cifar
from archs.mnist_lenet import LeNet
from datasets import get_normalize_layer, get_input_center_layer
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from torchvision.models.resnet import resnet50

# Architecture names accepted by get_architecture().
ARCHITECTURES = [
    "lenet",
    "resnet110",
]

def get_architecture(arch: str, dataset: str) -> torch.nn.Module:
    """Return a neural network (with random weights).

    The network is moved to the GPU with ``.cuda()`` before being returned.

    :param arch: the architecture - should be in the ARCHITECTURES list above
    :param dataset: the dataset - should be in the datasets.DATASETS list.
        Only consulted for "resnet110", where it selects the normalization
        layer placed in front of the network.
    :return: a Pytorch module
    :raises ValueError: if ``arch`` is not a supported architecture name
    """
    if arch == "resnet110":
        # CIFAR-style ResNet with a fixed 10-way classifier head.
        model = resnet_cifar(depth=110, num_classes=10).cuda()
        # The dataset-specific normalization is applied as the first layer
        # of the returned module.
        normalize_layer = get_normalize_layer(dataset)
        return torch.nn.Sequential(normalize_layer, model)
    if arch == "lenet":
        # LeNet is returned on its own; no normalization layer is prepended.
        return LeNet().cuda()
    # Fail loudly instead of implicitly returning None, which would only
    # surface later as a confusing error at the call site.
    raise ValueError(
        f"Unknown architecture {arch!r}; expected one of ['lenet', 'resnet110']"
    )
    
